/* ************************************************************************
 * Copyright (c) 2018 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * ************************************************************************ */

#include "../../utils/def.hpp"
#include "preconditioner_ai.hpp"
#include "../solver.hpp"

#include "../../base/local_matrix.hpp"

#include "../../base/local_vector.hpp"

#include "../../utils/log.hpp"

#include <math.h>
#include <complex>

namespace rocalution {

template <class OperatorType, class VectorType, typename ValueType>
AIChebyshev<OperatorType, VectorType, ValueType>::AIChebyshev()
{
    log_debug(this, "AIChebyshev::AIChebyshev()", "default constructor");

    this->p_          = 0;
    this->lambda_min_ = static_cast<ValueType>(0);
    this->lambda_max_ = static_cast<ValueType>(0);
}

template <class OperatorType, class VectorType, typename ValueType>
AIChebyshev<OperatorType, VectorType, ValueType>::~AIChebyshev()
{
    log_debug(this, "AIChebyshev::~AIChebyshev()", "destructor");

    this->Clear();
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::Print(void) const
{
    LOG_INFO("Approximate Inverse Chebyshev(" << this->p_ << ") preconditioner");

    if(this->build_ == true)
    {
        LOG_INFO("AI matrix nnz = " << this->AIChebyshev_.GetNnz());
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::Set(int p,
                                                           ValueType lambda_min,
                                                           ValueType lambda_max)
{
    log_debug(this, "AIChebyshev::Set()", p, lambda_min, lambda_max);

    assert(p > 0);
    assert(lambda_min != static_cast<ValueType>(0));
    assert(lambda_max != static_cast<ValueType>(0));
    assert(this->build_ == false);

    this->p_          = p;
    this->lambda_min_ = lambda_min;
    this->lambda_max_ = lambda_max;
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::Build(void)
{
    log_debug(this, "AIChebyshev::Build()", this->build_, " #*# begin");

    if(this->build_ == true)
    {
        this->Clear();
    }

    assert(this->build_ == false);
    this->build_ = true;

    assert(this->op_ != NULL);

    this->AIChebyshev_.CloneFrom(*this->op_);

    ValueType q = (static_cast<ValueType>(1) - sqrt(this->lambda_min_ / this->lambda_max_)) /
                  (static_cast<ValueType>(1) + sqrt(this->lambda_min_ / this->lambda_max_));
    ValueType c = static_cast<ValueType>(1) / sqrt(this->lambda_min_ * this->lambda_max_);

    // Shifting
    // Z = 2/(beta-alpha) [A-(beta+alpha)/2]
    OperatorType Z;
    Z.CloneFrom(*this->op_);

    Z.AddScalarDiagonal(static_cast<ValueType>(-1) * (this->lambda_max_ + this->lambda_min_) /
                        (static_cast<ValueType>(2)));
    Z.ScaleDiagonal(static_cast<ValueType>(2) / (this->lambda_max_ - this->lambda_min_));

    // Chebyshev formula/series
    // ai = I c_0 / 2 + sum c_k T_k
    // Tk = 2 Z T_k-1 - T_k-2

    // 1st term
    // T_0 = I
    // ai = I c_0 / 2
    this->AIChebyshev_.AddScalarDiagonal(c / static_cast<ValueType>(2));

    OperatorType Tkm2;
    Tkm2.CloneFrom(Z);
    // 2nd term
    // T_1 = Z
    // ai = ai + c_1 Z
    c = c * static_cast<ValueType>(-1) * q;
    this->AIChebyshev_.MatrixAdd(Tkm2, static_cast<ValueType>(1), c, true);

    // T_2 = 2*Z*Z - I
    // + c (2*Z*Z - I)
    OperatorType Tkm1;
    Tkm1.CloneBackend(*this->op_);
    Tkm1.MatrixMult(Z, Z);
    Tkm1.Scale(static_cast<ValueType>(2));
    Tkm1.AddScalarDiagonal(static_cast<ValueType>(-1));

    c = c * static_cast<ValueType>(-1) * q;
    this->AIChebyshev_.MatrixAdd(Tkm1, static_cast<ValueType>(1), c, true);

    // T_k = 2 Z T_k-1 - T_k-2
    OperatorType Tk;
    Tk.CloneBackend(*this->op_);

    for(int i = 2; i <= this->p_; ++i)
    {
        Tk.MatrixMult(Z, Tkm1);
        Tk.MatrixAdd(Tkm2, static_cast<ValueType>(2), static_cast<ValueType>(-1), true);

        c = c * static_cast<ValueType>(-1) * q;
        this->AIChebyshev_.MatrixAdd(Tk, static_cast<ValueType>(1), c, true);

        if(i + 1 <= this->p_)
        {
            Tkm2.CloneFrom(Tkm1);
            Tkm1.CloneFrom(Tk);
        }
    }

    log_debug(this, "AIChebyshev::Build()", this->build_, " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::Clear(void)
{
    log_debug(this, "AIChebyshev::Clear()", this->build_);

    this->AIChebyshev_.Clear();
    this->build_ = false;
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::MoveToHostLocalData_(void)
{
    log_debug(this, "AIChebyshev::MoveToHostLocalData_()", this->build_);

    this->AIChebyshev_.MoveToHost();
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::MoveToAcceleratorLocalData_(void)
{
    log_debug(this, "AIChebyshev::MoveToAcceleratorLocalData_()", this->build_);

    this->AIChebyshev_.MoveToAccelerator();
}

template <class OperatorType, class VectorType, typename ValueType>
void AIChebyshev<OperatorType, VectorType, ValueType>::Solve(const VectorType& rhs, VectorType* x)
{
    log_debug(this, "AIChebyshev::Solve()", " #*# begin");

    assert(this->build_ == true);
    assert(x != NULL);
    assert(x != &rhs);

    this->AIChebyshev_.Apply(rhs, x);

    log_debug(this, "AIChebyshev::Solve()", " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
FSAI<OperatorType, VectorType, ValueType>::FSAI()
{
    log_debug(this, "FSAI::FSAI()", "default constructor");

    this->op_mat_format_      = false;
    this->precond_mat_format_ = CSR;

    this->matrix_power_     = 1;
    this->external_pattern_ = false;
    this->matrix_pattern_   = NULL;
}

template <class OperatorType, class VectorType, typename ValueType>
FSAI<OperatorType, VectorType, ValueType>::~FSAI()
{
    log_debug(this, "FSAI::~FSAI()", "destructor");

    this->Clear();
    this->matrix_pattern_ = NULL;
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::Print(void) const
{
    LOG_INFO("Factorized Sparse Approximate Inverse preconditioner");

    if(this->build_ == true)
    {
        LOG_INFO("FSAI matrix nnz = " << this->FSAI_L_.GetNnz() + this->FSAI_LT_.GetNnz() -
                                             this->FSAI_L_.GetM());
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::Set(int power)
{
    log_debug(this, "FSAI::Set()", power);

    assert(this->build_ == false);
    assert(power > 0);

    this->matrix_power_ = power;
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::Set(const OperatorType& pattern)
{
    log_debug(this, "FSAI::Set()", "");

    assert(this->build_ == false);

    this->matrix_pattern_ = &pattern;
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::Build(void)
{
    log_debug(this, "FSAI::Build()", this->build_, " #*# begin");

    if(this->build_ == true)
    {
        this->Clear();
    }

    assert(this->build_ == false);
    this->build_ = true;

    assert(this->op_ != NULL);

    this->FSAI_L_.CloneFrom(*this->op_);
    this->FSAI_L_.FSAI(this->matrix_power_, this->matrix_pattern_);

    this->FSAI_LT_.CloneFrom(this->FSAI_L_);
    this->FSAI_LT_.Transpose();

    this->t_.CloneBackend(*this->op_);
    this->t_.Allocate("temporary", this->op_->GetM());

    if(this->op_mat_format_ == true)
    {
        this->FSAI_L_.ConvertTo(this->precond_mat_format_);
        this->FSAI_LT_.ConvertTo(this->precond_mat_format_);
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::Clear(void)
{
    log_debug(this, "FSAI::Clear()", this->build_);

    if(this->build_ == true)
    {
        this->FSAI_L_.Clear();
        this->FSAI_LT_.Clear();

        this->t_.Clear();

        this->op_mat_format_      = false;
        this->precond_mat_format_ = CSR;

        this->build_ = false;
    }

    log_debug(this, "FSAI::Build()", this->build_, " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::SetPrecondMatrixFormat(unsigned int mat_format)
{
    log_debug(this, "FSAI::SetPrecondMatrixFormat()", mat_format);

    this->op_mat_format_      = true;
    this->precond_mat_format_ = mat_format;
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::MoveToHostLocalData_(void)
{
    log_debug(this, "FSAI::MoveToHostLocalData_()", this->build_);

    this->FSAI_L_.MoveToHost();
    this->FSAI_LT_.MoveToHost();

    this->t_.MoveToHost();
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::MoveToAcceleratorLocalData_(void)
{
    log_debug(this, "FSAI::MoveToAcceleratorLocalData_()", this->build_);

    this->FSAI_L_.MoveToAccelerator();
    this->FSAI_LT_.MoveToAccelerator();

    this->t_.MoveToAccelerator();
}

template <class OperatorType, class VectorType, typename ValueType>
void FSAI<OperatorType, VectorType, ValueType>::Solve(const VectorType& rhs, VectorType* x)
{
    log_debug(this, "FSAI::Solve()", " #*# begin");

    assert(this->build_ == true);
    assert(x != NULL);
    assert(x != &rhs);

    this->FSAI_L_.Apply(rhs, &this->t_);
    this->FSAI_LT_.Apply(this->t_, x);

    log_debug(this, "FSAI::Solve()", " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
SPAI<OperatorType, VectorType, ValueType>::SPAI()
{
    log_debug(this, "SPAI::SPAI()", "default constructor");

    this->op_mat_format_      = false;
    this->precond_mat_format_ = CSR;
}

template <class OperatorType, class VectorType, typename ValueType>
SPAI<OperatorType, VectorType, ValueType>::~SPAI()
{
    log_debug(this, "SPAI::~SPAI()", "destructor");

    this->Clear();
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::Print(void) const
{
    LOG_INFO("SParse Approximate Inverse preconditioner");

    if(this->build_ == true)
    {
        LOG_INFO("SPAI matrix nnz = " << this->SPAI_.GetNnz());
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::Build(void)
{
    log_debug(this, "SPAI::Build()", this->build_, " #*# begin");

    if(this->build_ == true)
    {
        this->Clear();
    }

    assert(this->build_ == false);
    this->build_ = true;

    assert(this->op_ != NULL);

    this->SPAI_.CloneFrom(*this->op_);
    this->SPAI_.SPAI();

    if(this->op_mat_format_ == true)
    {
        this->SPAI_.ConvertTo(this->precond_mat_format_);
    }

    log_debug(this, "SPAI::Build()", this->build_, " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::Clear(void)
{
    log_debug(this, "SPAI::Clear()", this->build_);

    if(this->build_ == true)
    {
        this->SPAI_.Clear();

        this->op_mat_format_      = false;
        this->precond_mat_format_ = CSR;

        this->build_ = false;
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::SetPrecondMatrixFormat(unsigned int mat_format)
{
    log_debug(this, "SPAI::SetPrecondMatrixFormat()", mat_format);

    this->op_mat_format_      = true;
    this->precond_mat_format_ = mat_format;
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::MoveToHostLocalData_(void)
{
    log_debug(this, "SPAI::MoveToHostLocalData_()", this->build_);

    this->SPAI_.MoveToHost();
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::MoveToAcceleratorLocalData_(void)
{
    log_debug(this, "SPAI::MoveToAcceleratorLocalData_()", this->build_);

    this->SPAI_.MoveToAccelerator();
}

template <class OperatorType, class VectorType, typename ValueType>
void SPAI<OperatorType, VectorType, ValueType>::Solve(const VectorType& rhs, VectorType* x)
{
    log_debug(this, "SPAI::Solve()", " #*# begin");

    assert(this->build_ == true);
    assert(x != NULL);
    assert(x != &rhs);

    this->SPAI_.Apply(rhs, x);

    log_debug(this, "SPAI::Solve()", " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
TNS<OperatorType, VectorType, ValueType>::TNS()
{
    log_debug(this, "TNS::TNS()", "default constructor");

    this->op_mat_format_      = false;
    this->precond_mat_format_ = CSR;

    this->impl_ = true;
}

template <class OperatorType, class VectorType, typename ValueType>
TNS<OperatorType, VectorType, ValueType>::~TNS()
{
    log_debug(this, "TNS::~TNS()", "destructor");

    this->Clear();
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::Print(void) const
{
    LOG_INFO("Truncated Neumann Series (TNS) Preconditioner");

    if(this->build_ == true)
    {
        if(this->impl_ == true)
        {
            LOG_INFO("Implicit TNS L matrix nnz = " << this->L_.GetNnz());
        }
        else
        {
            LOG_INFO("Explicit TNS matrix nnz = " << this->TNS_.GetNnz());
        }
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::Set(bool imp)
{
    assert(this->build_ != true);

    this->impl_ = imp;
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::Build(void)
{
    log_debug(this, "TNS::Build()", this->build_, " #*# begin");

    if(this->build_ == true)
    {
        this->Clear();
    }

    assert(this->build_ == false);
    this->build_ = true;

    assert(this->op_ != NULL);

    if(this->impl_ == true)
    {
        // implicit computation

        this->L_.CloneBackend(*this->op_);
        this->LT_.CloneBackend(*this->op_);

        this->tmp1_.CloneBackend(*this->op_);
        this->tmp2_.CloneBackend(*this->op_);
        this->Dinv_.CloneBackend(*this->op_);

        this->op_->ExtractInverseDiagonal(&this->Dinv_);

        this->op_->ExtractL(&this->L_, false);
        this->L_.DiagonalMatrixMultR(this->Dinv_);

        this->LT_.CloneFrom(this->L_);
        this->LT_.Transpose();

        this->tmp1_.Allocate("tmp1 vec for TNS", this->op_->GetM());
        this->tmp2_.Allocate("tmp2 vec for TNS", this->op_->GetM());
    }
    else
    {
        // explicit computation

        OperatorType K, KT;

        this->L_.CloneBackend(*this->op_);
        this->Dinv_.CloneBackend(*this->op_);
        this->TNS_.CloneBackend(*this->op_);

        K.CloneBackend(*this->op_);
        KT.CloneBackend(*this->op_);

        this->op_->ExtractInverseDiagonal(&this->Dinv_);

        // get the diagonal but flash them to zero
        // keep the structure
        this->op_->ExtractL(&this->L_, true);
        this->L_.ScaleDiagonal(static_cast<ValueType>(0));
        this->L_.DiagonalMatrixMultR(this->Dinv_);

        K.MatrixMult(this->L_, this->L_);

        // add -I
        this->L_.AddScalarDiagonal(static_cast<ValueType>(-1));

        K.MatrixAdd(this->L_,
                    static_cast<ValueType>(1),  // for L^2
                    static_cast<ValueType>(-1), // for (-I+L)
                    true);

        KT.CloneFrom(K);
        KT.Transpose();

        KT.DiagonalMatrixMultR(this->Dinv_);

        this->TNS_.MatrixMult(KT, K);

        K.Clear();
        KT.Clear();

        this->L_.Clear();
        this->Dinv_.Clear();
    }

    if(this->op_mat_format_ == true)
    {
        this->TNS_.ConvertTo(this->precond_mat_format_);
        this->L_.ConvertTo(this->precond_mat_format_);
        this->LT_.ConvertTo(this->precond_mat_format_);
    }

    log_debug(this, "TNS::Build()", this->build_, " #*# end");
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::Clear(void)
{
    log_debug(this, "TNS::Clear()", this->build_);

    if(this->build_ == true)
    {
        this->TNS_.Clear();

        this->L_.Clear();
        this->LT_.Clear();
        this->Dinv_.Clear();

        this->tmp1_.Clear();
        this->tmp2_.Clear();

        this->op_mat_format_      = false;
        this->precond_mat_format_ = CSR;

        this->build_ = false;
    }
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::SetPrecondMatrixFormat(unsigned int mat_format)
{
    log_debug(this, "TNS::SetPrecondMatrixFormat()", mat_format);

    this->op_mat_format_      = true;
    this->precond_mat_format_ = mat_format;
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::MoveToHostLocalData_(void)
{
    log_debug(this, "TNS::MoveToHostLocalData_()", this->build_);

    this->TNS_.MoveToHost();
    this->L_.MoveToHost();
    this->LT_.MoveToHost();
    this->Dinv_.MoveToHost();
    this->tmp1_.MoveToHost();
    this->tmp2_.MoveToHost();
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::MoveToAcceleratorLocalData_(void)
{
    log_debug(this, "TNS::MoveToAcceleratorLocalData_()", this->build_);

    this->TNS_.MoveToHost();
    this->L_.MoveToAccelerator();
    this->LT_.MoveToAccelerator();
    this->Dinv_.MoveToAccelerator();
    this->tmp1_.MoveToAccelerator();
    this->tmp2_.MoveToAccelerator();
}

template <class OperatorType, class VectorType, typename ValueType>
void TNS<OperatorType, VectorType, ValueType>::Solve(const VectorType& rhs, VectorType* x)
{
    log_debug(this, "TNS::Solve()", " #*# begin");

    assert(this->build_ == true);
    assert(x != NULL);
    assert(x != &rhs);

    if(this->impl_ == true)
    {
        // implicit

        this->L_.Apply(rhs, &this->tmp1_);
        this->L_.Apply(this->tmp1_, &this->tmp2_);
        this->tmp1_.AddScale(this->tmp2_, static_cast<ValueType>(-1));

        x->CopyFrom(rhs);
        x->AddScale(this->tmp1_, static_cast<ValueType>(-1));

        x->PointWiseMult(this->Dinv_);

        this->LT_.Apply(*x, &this->tmp1_);
        this->LT_.Apply(this->tmp1_, &this->tmp2_);

        x->ScaleAdd2(static_cast<ValueType>(1),
                     this->tmp1_,
                     static_cast<ValueType>(-1),
                     this->tmp2_,
                     static_cast<ValueType>(1));
    }
    else
    {
        // explicit

        this->TNS_.Apply(rhs, x);
    }

    //  LOG_INFO(x->Norm());

    log_debug(this, "TNS::Solve()", " #*# end");
}

template class AIChebyshev<LocalMatrix<double>, LocalVector<double>, double>;
template class AIChebyshev<LocalMatrix<float>, LocalVector<float>, float>;
#ifdef SUPPORT_COMPLEX
template class AIChebyshev<LocalMatrix<std::complex<double>>,
                           LocalVector<std::complex<double>>,
                           std::complex<double>>;
template class AIChebyshev<LocalMatrix<std::complex<float>>,
                           LocalVector<std::complex<float>>,
                           std::complex<float>>;
#endif

template class FSAI<LocalMatrix<double>, LocalVector<double>, double>;
template class FSAI<LocalMatrix<float>, LocalVector<float>, float>;
#ifdef SUPPORT_COMPLEX
template class FSAI<LocalMatrix<std::complex<double>>,
                    LocalVector<std::complex<double>>,
                    std::complex<double>>;
template class FSAI<LocalMatrix<std::complex<float>>,
                    LocalVector<std::complex<float>>,
                    std::complex<float>>;
#endif

template class SPAI<LocalMatrix<double>, LocalVector<double>, double>;
template class SPAI<LocalMatrix<float>, LocalVector<float>, float>;
#ifdef SUPPORT_COMPLEX
template class SPAI<LocalMatrix<std::complex<double>>,
                    LocalVector<std::complex<double>>,
                    std::complex<double>>;
template class SPAI<LocalMatrix<std::complex<float>>,
                    LocalVector<std::complex<float>>,
                    std::complex<float>>;
#endif

template class TNS<LocalMatrix<double>, LocalVector<double>, double>;
template class TNS<LocalMatrix<float>, LocalVector<float>, float>;
#ifdef SUPPORT_COMPLEX
template class TNS<LocalMatrix<std::complex<double>>,
                   LocalVector<std::complex<double>>,
                   std::complex<double>>;
template class TNS<LocalMatrix<std::complex<float>>,
                   LocalVector<std::complex<float>>,
                   std::complex<float>>;
#endif

} // namespace rocalution
